{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ur8xi4C7S06n"
      },
      "outputs": [],
      "source": [
        "# Copyright 2024 Google LLC\n",
        "#\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "#     https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JAPoU8Sm5E6e"
      },
      "source": [
        "# Text Classification with Generative Models on Vertex AI\n",
        "\n",
        "\n",
        "<table align=\"left\">\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://colab.research.google.com/github/GoogleCloudPlatform/generative-ai/blob/main/gemini/prompts/examples/text_classification.ipynb\">\n",
        "      <img width=\"32px\" src=\"https://www.gstatic.com/pantheon/images/bigquery/welcome_page/colab-logo.svg\" alt=\"Google Colaboratory logo\"><br> Open in Colab\n",
        "    </a>\n",
        "  </td>\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fgenerative-ai%2Fmain%2Fgemini%2Fprompts%2Fexamples%2Ftext_classification.ipynb\">\n",
        "      <img width=\"32px\" src=\"https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN\" alt=\"Google Cloud Colab Enterprise logo\"><br> Open in Colab Enterprise\n",
        "    </a>\n",
        "  </td>    \n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/generative-ai/main/gemini/prompts/examples/text_classification.ipynb\">\n",
        "      <img src=\"https://lh3.googleusercontent.com/UiNooY4LUgW_oTvpsNhPpQzsstV5W8F7rYgxgGBD85cWJoLmrOzhVs_ksK_vgx40SHs7jCqkTkCk=e14-rj-sc0xffffff-h130-w32\" alt=\"Vertex AI logo\"><br> Open in Workbench\n",
        "    </a>\n",
        "  </td>\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/prompts/examples/text_classification.ipynb\">\n",
        "      <img width=\"32px\" src=\"https://raw.githubusercontent.com/primer/octicons/refs/heads/main/icons/mark-github-24.svg\" alt=\"GitHub logo\"><br> View on GitHub\n",
        "    </a>\n",
        "  </td>\n",
        "</table>\n",
        "\n",
        "<div style=\"clear: both;\"></div>\n",
        "\n",
        "<b>Share to:</b>\n",
        "\n",
        "<a href=\"https://www.linkedin.com/sharing/share-offsite/?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/prompts/examples/text_classification.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/8/81/LinkedIn_icon.svg\" alt=\"LinkedIn logo\">\n",
        "</a>\n",
        "\n",
        "<a href=\"https://bsky.app/intent/compose?text=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/prompts/examples/text_classification.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/7/7a/Bluesky_Logo.svg\" alt=\"Bluesky logo\">\n",
        "</a>\n",
        "\n",
        "<a href=\"https://twitter.com/intent/tweet?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/prompts/examples/text_classification.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/5/5a/X_icon_2.svg\" alt=\"X logo\">\n",
        "</a>\n",
        "\n",
        "<a href=\"https://reddit.com/submit?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/prompts/examples/text_classification.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://redditinc.com/hubfs/Reddit%20Inc/Brand/Reddit_Logo.png\" alt=\"Reddit logo\">\n",
        "</a>\n",
        "\n",
        "<a href=\"https://www.facebook.com/sharer/sharer.php?u=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/prompts/examples/text_classification.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/5/51/Facebook_f_logo_%282019%29.svg\" alt=\"Facebook logo\">\n",
        "</a>            \n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PpFdtPxGhVST"
      },
      "source": [
        "| Authors |\n",
        "| --- |\n",
        "| [Polong Lin](https://github.com/polong-lin) |\n",
        "| [Deepak Moonat](https://github.com/dmoonat) |"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tvgnzT1CKxrO"
      },
      "source": [
        "## Overview\n",
        "\n",
        "Generative models like Gemini are large language models (LLMs) that can handle many natural language processing (NLP) tasks. One of these is text classification: assigning one or more categories to a piece of text. Traditional NLP techniques typically require domain-specific labeled data to train a classifier, whereas an LLM can classify text from a prompt alone, which can shorten the time it takes to build a text classification solution. LLM-based classifiers can be tuned further with many examples via custom model training, but that is beyond the scope of this notebook.\n",
        "\n",
        "In this notebook, you will explore how to do text classification using prompts with the Gemini API. Learn more about classification prompts in the [official documentation](https://cloud.google.com/vertex-ai/docs/generative-ai/text/classification-prompts)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "d975e698c9a4"
      },
      "source": [
        "### Objective\n",
        "\n",
        "By the end of the notebook, you should be able to use a large language model to perform various classification tasks, including:\n",
        "\n",
        "* Zero-shot prompting text classification\n",
        "* Few-shot prompting text classification\n",
        "* Common tasks:\n",
        "    * Sentiment analysis\n",
        "    * Topic classification\n",
        "    * Spam detection\n",
        "    * Intent recognition\n",
        "    * Language identification\n",
        "    * Toxicity detection\n",
        "    * Emotion detection"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nSyXtwyz_o_v"
      },
      "source": [
        "## Getting Started"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2a5AEr0lkLKD"
      },
      "source": [
        "### Install Vertex AI SDK and other required packages"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "82ad0c445061"
      },
      "outputs": [],
      "source": [
        "%pip install google-cloud-aiplatform --upgrade -q\n",
        "%pip install pandas scikit-learn -q"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Xe7OuYuGkLKF"
      },
      "source": [
        "### Authenticate your notebook environment (Colab only)\n",
        "\n",
        "If you are running this notebook on Google Colab, run the cell below to authenticate your environment.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "U9Gx2SAZkLKF"
      },
      "outputs": [],
      "source": [
        "import sys\n",
        "\n",
        "if \"google.colab\" in sys.modules:\n",
        "    from google.colab import auth\n",
        "\n",
        "    auth.authenticate_user()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5RVGmd8dhVSV"
      },
      "source": [
        "### Set Google Cloud project information and initialize Vertex AI SDK\n",
        "\n",
        "To get started using Vertex AI, you must have an existing Google Cloud project and [enable the Vertex AI API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com).\n",
        "\n",
        "Learn more about [setting up a project and a development environment](https://cloud.google.com/vertex-ai/docs/start/cloud-environment).\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AhDbnhB0hVSV"
      },
      "outputs": [],
      "source": [
        "PROJECT_ID = \"your-project-id\"  # @param {type:\"string\"}\n",
        "LOCATION = \"us-central1\"  # @param {type:\"string\"}\n",
        "\n",
        "import vertexai\n",
        "\n",
        "vertexai.init(project=PROJECT_ID, location=LOCATION)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "960505627ddf"
      },
      "source": [
        "### Import libraries"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PyQmSRbKA8r-"
      },
      "outputs": [],
      "source": [
        "import warnings\n",
        "\n",
        "warnings.filterwarnings(\"ignore\")\n",
        "\n",
        "import pandas as pd\n",
        "from sklearn.metrics import classification_report\n",
        "from vertexai.generative_models import GenerationConfig, GenerativeModel"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UP76a2la7O-a"
      },
      "source": [
        "### Load model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7isig7e07O-a"
      },
      "outputs": [],
      "source": [
        "generation_model = GenerativeModel(\"gemini-2.0-flash\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RvHZ-R0umYYR"
      },
      "source": [
        "#### Generation config\n",
        "\n",
        "- Each call that you send to a model includes parameter values that control how the model generates a response. The model can generate different results for different parameter values.\n",
        "- <strong>Experiment</strong> with different parameter values to find the best values for the task.\n",
        "\n",
        "See [Experiment with different parameter values](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/prompt-design-strategies#experiment-with-different-parameter-values) to learn what each parameter does."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZnyExfPTi68w"
      },
      "outputs": [],
      "source": [
        "generation_config = GenerationConfig(temperature=0.1, max_output_tokens=256)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fIPcn5dZ7O-b"
      },
      "source": [
        "## Text Classification"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "l2eDjxvafo5W"
      },
      "source": [
        "In the section below, you will explore zero-shot prompting, few-shot prompting, and some common types of text classification tasks."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bC3qkPZ8jFkY"
      },
      "source": [
        "### Zero-shot prompting"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "W8RFu2ZX_o_y"
      },
      "source": [
        "In zero-shot prompting, you provide no labeled examples and rely on the LLM to classify the text from the instruction alone."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_uNNGhC_e1nZ"
      },
      "outputs": [],
      "source": [
        "prompt = \"\"\"\n",
        "Classify the following:\\n\n",
        "text: \"I saw a furry animal in the park today with a long tail and big eyes.\"\n",
        "label: dogs, cats\n",
        "\"\"\"\n",
        "\n",
        "response = generation_model.generate_content(\n",
        "    contents=prompt, generation_config=generation_config\n",
        ").text\n",
        "print(response)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tjl-tckTjK2B"
      },
      "source": [
        "### Few-shot prompting"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5UC0w7n5_o_z"
      },
      "source": [
        "With few-shot prompting, you include a few labeled examples in the prompt and expect the Gemini model to predict the class of new text based on those examples."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "b8jYL-hBjMtW"
      },
      "outputs": [],
      "source": [
        "prompt = \"\"\"\n",
        "What is the topic for a given news headline? \\n\n",
        "- business \\n\n",
        "- entertainment \\n\n",
        "- health \\n\n",
        "- sports \\n\n",
        "- technology \\n\\n\n",
        "\n",
        "Text: Pixel 7 Pro Expert Hands On Review. \\n\n",
        "The answer is: technology \\n\n",
        "\n",
        "Text: Quit smoking? \\n\n",
        "The answer is: health \\n\n",
        "\n",
        "Text: Birdies or bogeys? Top 5 tips to hit under par \\n\n",
        "The answer is: sports \\n\n",
        "\n",
        "Text: Relief from local minimum-wage hike looking more remote \\n\n",
        "The answer is: business \\n\n",
        "\n",
        "Text: You won't guess who just arrived in Bari, Italy for the movie premiere. \\n\n",
        "The answer is:\n",
        "\"\"\"\n",
        "\n",
        "response = generation_model.generate_content(\n",
        "    contents=prompt, generation_config=generation_config\n",
        ").text\n",
        "print(response)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WaiMLs4SjNLi"
      },
      "source": [
        "### Other classification examples"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LhkcRWrh_o_0"
      },
      "source": [
        "Explore some more common text classification prompts below, which are all based on zero-shot prompts. You can also turn some of these into few-shot prompts by providing your own custom examples of text and the associated output classes."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8tEjKEAtXjf8"
      },
      "source": [
        "#### Topic classification"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bCnuQRVSXmyr"
      },
      "outputs": [],
      "source": [
        "prompt = \"\"\"\n",
        "Classify a piece of text into one of several predefined topics, such as sports, politics, or entertainment. \\n\n",
        "text: President Biden will be visiting India in the month of March to discuss a few opportunities. \\n\n",
        "class:\n",
        "\"\"\"\n",
        "\n",
        "response = generation_model.generate_content(\n",
        "    contents=prompt, generation_config=generation_config\n",
        ").text\n",
        "print(response)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rB6rZD-6YAkC"
      },
      "source": [
        "#### Spam detection"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OfyHvhBfX_P_"
      },
      "outputs": [],
      "source": [
        "prompt = \"\"\"\n",
        "Given an email, classify it as spam or not spam. \\n\n",
        "email: hi user, \\n\n",
        "      you have been selected as a winner of the lottery and can win upto 1 million dollar. \\n\n",
        "      kindly share your bank details and we can proceed from there. \\n\\n\n",
        "\n",
        "      from, \\n\n",
        "      US Official Lottry Depatmint\n",
        "\"\"\"\n",
        "\n",
        "response = generation_model.generate_content(\n",
        "    contents=prompt, generation_config=generation_config\n",
        ").text\n",
        "print(response)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cHKcxx0TYrGv"
      },
      "source": [
        "#### Intent recognition"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DsseGvWNYvK3"
      },
      "outputs": [],
      "source": [
        "prompt = \"\"\"\n",
        "Given a user's input, classify their intent, such as \"finding information\", \"making a reservation\", or \"placing an order\". \\n\n",
        "user input: Hi, can you please book a table for two at Juan for May 1?\n",
        "\"\"\"\n",
        "\n",
        "response = generation_model.generate_content(\n",
        "    contents=prompt, generation_config=generation_config\n",
        ").text\n",
        "print(response)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "stsfav5aZtqV"
      },
      "source": [
        "#### Language identification"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "88TqQSXIZxts"
      },
      "outputs": [],
      "source": [
        "prompt = \"\"\"\n",
        "Given a piece of text, classify the language it is written in. \\n\n",
        "text: Selam nasıl gidiyor?\n",
        "language:\n",
        "\"\"\"\n",
        "\n",
        "response = generation_model.generate_content(\n",
        "    contents=prompt, generation_config=generation_config\n",
        ").text\n",
        "print(response)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3Z_jmrhOZ15J"
      },
      "source": [
        "#### Toxicity detection"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Umloy-o1Z5us"
      },
      "outputs": [],
      "source": [
        "prompt = \"\"\"\n",
        "Given a piece of text, classify it as toxic or non-toxic. \\n\n",
        "text: i love sunny days\n",
        "\"\"\"\n",
        "\n",
        "response = generation_model.generate_content(\n",
        "    contents=prompt, generation_config=generation_config\n",
        ").text\n",
        "print(response)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rTH5MeiVZ6Hr"
      },
      "source": [
        "#### Emotion detection"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "u5ETwBSrZ-Xn"
      },
      "outputs": [],
      "source": [
        "prompt = \"\"\"\n",
        "Given a piece of text, classify the emotion it conveys, such as happiness or anger. \\n\n",
        "text: I'm still so delighted from yesterday's news\n",
        "\"\"\"\n",
        "\n",
        "response = generation_model.generate_content(\n",
        "    contents=prompt, generation_config=generation_config\n",
        ").text\n",
        "print(response)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7ddaadac64c7"
      },
      "source": [
        "### Evaluation"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "d5e2266cb257"
      },
      "source": [
        "You can evaluate the outputs of the text classification task if the ground truth classes are available. To see how this works, start by creating a small pandas DataFrame of product reviews and their ground truth sentiment."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "b0e04a03f24f"
      },
      "outputs": [],
      "source": [
        "review_data_df = pd.DataFrame(\n",
        "    {\n",
        "        \"review\": [\n",
        "            \"i love this product. it does have everything i am looking for!\",\n",
        "            \"all i can say is that you will be happy after buying this product\",\n",
        "            \"its way too expensive and not worth the price\",\n",
        "            \"i am feeling okay. its neither good nor too bad.\",\n",
        "        ],\n",
        "        \"sentiment_groundtruth\": [\"positive\", \"positive\", \"negative\", \"neutral\"],\n",
        "    }\n",
        ")\n",
        "\n",
        "review_data_df"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "088327f41a26"
      },
      "source": [
        "Now that you have reviews with ground truth sentiment labels, call the model on each review using the pandas `apply` function. Each value in the `review` column is inserted into a prompt, the Gemini API predicts its sentiment, and the result is stored in the `sentiment_prediction` column."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0fb691b6c721"
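      },
      "outputs": [],
      "source": [
        "def get_sentiment(review: str) -> str:\n",
        "    \"\"\"Classify one review and return a normalized sentiment label.\"\"\"\n",
        "    prompt = f\"\"\"Classify the sentiment of the following review as \"positive\", \"neutral\", or \"negative\". \\n\\n\n",
        "                review: {review} \\n\n",
        "                sentiment:\n",
        "              \"\"\"\n",
        "    response = generation_model.generate_content(\n",
        "        contents=prompt, generation_config=generation_config\n",
        "    ).text\n",
        "    # Strip whitespace and lowercase so the label matches the ground truth exactly.\n",
        "    return response.strip().lower()\n",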
        "\n",
        "\n",
        "review_data_df[\"sentiment_prediction\"] = review_data_df[\"review\"].apply(get_sentiment)\n",
        "review_data_df"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "908c72bdf4c7"
      },
      "source": [
        "Finally, call the `classification_report` function from scikit-learn to compute precision, recall, F1-score, and accuracy by passing the ground truth labels (`sentiment_groundtruth`) and the predicted labels (`sentiment_prediction`):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7f22690ae395"
      },
      "outputs": [],
      "source": [
        "report = classification_report(\n",
        "    review_data_df[\"sentiment_groundtruth\"], review_data_df[\"sentiment_prediction\"]\n",
        ")\n",
        "print(report)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "name": "text_classification.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
